import unittest

import torch
from torchsummary import summary

from training.models.color_space import UnetColorSpaceGeneratorV2


class UnetColorSpaceGeneratorV2TestCase(unittest.TestCase):
    """Smoke tests for ``UnetColorSpaceGeneratorV2``."""

    # The model is moved to CUDA and summary() is told to use a CUDA device,
    # so this test cannot run without a GPU. Skip (rather than fail) on such
    # hosts so a missing prerequisite is not reported as a model defect.
    @unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
    def test_unet_color_space(self):
        """Build the generator on CUDA and render its torchsummary summary.

        The test passes if ``summary`` completes without raising; the model's
        output values are not checked.
        """
        model = UnetColorSpaceGeneratorV2(in_channels=9, out_channels=3)
        model = model.cuda()
        # NOTE(review): the model is constructed with in_channels=9 but is
        # summarised with a 3-channel input; presumably the model expands its
        # 3-channel input internally before the 9-channel stage — confirm
        # against UnetColorSpaceGeneratorV2.
        input_size = (3, 224, 224)  # per-sample shape (no batch dim), as torchsummary expects
        summary(model, input_size=input_size, device="cuda")

if "__main__" == __name__:
    # Executed directly as a script: discover and run this module's TestCases.
    unittest.main(module="__main__")
